import os
import pickle
import numpy as np
from utils.data_util import get_seqs


def continuous_cropping(dimer_seq, dist_map, crop_size):
    """Crop one contiguous window from each chain of a dimer.

    The crop budget ``crop_size`` is split between the two chains roughly in
    proportion to their lengths.  When one chain is more than twice as long
    as the other, the shorter chain's share is scaled up by the length ratio
    (capped at its full length) so that it is not squeezed out of the crop.
    Window start positions are drawn with ``np.random.randint``.

    Args:
        dimer_seq: ``[seq1, seq2]``, the two chain sequences.
        dist_map: numpy array indexed as ``(N, N)`` with
            ``N = len(seq1) + len(seq2)``; chain 1 residues come first.
        crop_size: total number of residues to keep across both chains.

    Returns:
        ``(cr_dimer_seq, cr_dist_map)``: the two cropped sequences and the
        matching sub-matrix of ``dist_map``.  If the dimer already fits in
        ``crop_size`` it is returned uncropped (with a copy of ``dist_map``)
        and no random numbers are drawn.
    """
    len1, len2 = [len(x) for x in dimer_seq]
    num_res = len1 + len2

    # Nothing to crop: the window arithmetic below would otherwise ask
    # randint for an empty/negative range and raise ValueError.
    if num_res <= crop_size:
        return [dimer_seq[0], dimer_seq[1]], dist_map.copy()

    shorter = min(len1, len2)
    if shorter == 0:
        # An empty chain cannot contribute residues (and would make the
        # length ratio a division by zero): take the whole crop from the
        # non-empty chain.
        crop_size1 = crop_size if len2 == 0 else 0
        crop_size2 = crop_size - crop_size1
    else:
        r = max(len1, len2) / shorter
        if r > 2:
            # Strongly unbalanced dimer: boost the shorter chain's share by
            # the ratio r, but never beyond the chain's own length.
            if len1 < len2:
                crop_size1 = min(len1, int(len1/num_res*crop_size*r))
                crop_size2 = crop_size - crop_size1
            else:
                crop_size2 = min(len2, int(len2/num_res*crop_size*r))
                crop_size1 = crop_size - crop_size2
        else:
            # Balanced dimer: split the budget proportionally to length.
            crop_size1 = int(len1/num_res*crop_size)
            crop_size2 = crop_size - crop_size1

    # Draw chain 1's window first, then chain 2's, so results for a given
    # seed stay reproducible.
    start1 = np.random.randint(0, int(len1-crop_size1)+1)
    end1 = start1 + crop_size1
    start2 = np.random.randint(0, int(len2-crop_size2)+1)
    end2 = start2 + crop_size2

    cr_dimer_seq = [dimer_seq[0][start1: end1], dimer_seq[1][start2: end2]]
    # Chain 2 indices are offset by len1 in the concatenated distance map.
    keep_index = np.arange(start1, end1).tolist() + np.arange(len1+start2, len1+end2).tolist()
    cr_dist_map = dist_map[keep_index, :][:, keep_index]
    return cr_dimer_seq, cr_dist_map


def spatial_cropping(seq, dist_map, crop_size, cb_cb_threshold=10):
    """Crop the ``crop_size`` residues spatially closest to an interface residue.

    A target residue is sampled from the chain 1 residues that have at least
    one chain 2 residue closer than ``cb_cb_threshold``, with probability
    proportional to its number of such contacts.  The ``crop_size`` residues
    nearest to the target (over both chains) are kept.  If no chain 1 residue
    has an inter-chain contact, this falls back to ``continuous_cropping``.

    Args:
        seq: ``[seq1, seq2]``, the two chain sequences.
        dist_map: numpy array of shape (N, N) with ``N = len(seq1) + len(seq2)``;
            chain 1 residues come first.  It is not modified.
        crop_size: number of residues to keep.
        cb_cb_threshold: distance below which an inter-chain residue pair
            counts as an interface contact.

    Returns:
        ``(cr_dimer_seq, cr_dist_map)``: the cropped sequences of both chains
        (residues kept in their original order, not necessarily contiguous)
        and the matching sub-matrix of ``dist_map``.
    """
    # get interface candidates
    len1 = len(seq[0])
    inter_chain_dist_map = dist_map[:len1, len1:]
    cnt_interfaces = (inter_chain_dist_map < cb_cb_threshold).sum(axis=-1)
    interface_candidates = np.where(cnt_interfaces != 0)[0]

    # Test for emptiness explicitly: the candidates are residue *indices*, so
    # a truthiness test such as np.any() would wrongly reject the case where
    # residue 0 is the only interface residue.
    if interface_candidates.size == 0:
        return continuous_cropping(seq, dist_map, crop_size)

    contact_counts = cnt_interfaces[interface_candidates]
    prob = contact_counts / contact_counts.sum()
    target_res = int(np.random.choice(interface_candidates, p=prob))   # int

    # A small index-dependent offset makes the ordering of equidistant
    # residues deterministic (lower index wins).  Build a new array rather
    # than adding in place: dist_map[target_res] is a view, and an in-place
    # add would corrupt the caller's distance map (and the returned crop).
    break_tie = np.arange(0, dist_map.shape[-1], dtype=float)*1e-3
    to_target_distances = dist_map[target_res] + break_tie   # of shape (N,)

    ret = np.argsort(to_target_distances)[:crop_size]
    ret.sort()
    cr_dist_map = dist_map[ret, :][:, ret]

    # Note that the crop might not be continuous
    chain1_index = ret[ret < len1]
    chain2_index = ret[ret >= len1]
    cr_seq1 = ''.join([seq[0][i] for i in chain1_index])
    cr_seq2 = ''.join([seq[1][i-len1] for i in chain2_index])
    cr_dimer_seq = [cr_seq1, cr_seq2]
    return cr_dimer_seq, cr_dist_map



def multi_chain_cropping(seqs, dist_maps, crop_size=220, spatial_crop_prob=0.5, cb_cb_threshold=10, seed=123):
    """Crop every dimer in a dataset down to at most ``crop_size`` residues.

    The global numpy RNG is seeded with ``seed`` first.  For each entry of
    ``dist_maps`` whose two chains together exceed ``crop_size``, spatial
    cropping is chosen with probability ``spatial_crop_prob`` and continuous
    cropping otherwise.

    Args:
        seqs: {name: record}, where each record provides the keys
            ``'chain1_seq'`` and ``'chain2_seq'``.
        dist_maps: {name: dist_map}
        crop_size: maximum total number of residues per dimer.
        spatial_crop_prob: probability of using spatial cropping.
        cb_cb_threshold: contact threshold forwarded to spatial cropping.
        seed: value passed to ``np.random.seed``.

    Returns:
        ``(cr_seqs, cr_dist_maps)``, both keyed by name.  Cropped dimers get a
        record with ``'dimer_seq'`` (the two chains joined by a comma),
        ``'chain1_len'`` and ``'chain2_len'``; dimers that already fit are
        passed through as the original record and distance map.
    """
    np.random.seed(seed)
    cropped_records = {}
    cropped_maps = {}

    for name, full_map in dist_maps.items():
        record = seqs[name]
        chains = [record['chain1_seq'], record['chain2_seq']]
        n_first, n_second = [len(chain) for chain in chains]

        # Short enough already: keep the original objects untouched.
        if n_first + n_second <= crop_size:
            cropped_records[name] = record
            cropped_maps[name] = full_map
            continue

        if np.random.rand() < spatial_crop_prob:
            cr_chains, cr_map = spatial_cropping(chains, full_map, crop_size, cb_cb_threshold)
        else:
            cr_chains, cr_map = continuous_cropping(chains, full_map, crop_size)

        n_first, n_second = [len(chain) for chain in cr_chains]
        cropped_records[name] = {
            'dimer_seq': ','.join(cr_chains),
            'chain1_len': n_first,
            'chain2_len': n_second,
        }
        cropped_maps[name] = cr_map

    return cropped_records, cropped_maps


def crop(seq_path, dist_path, crop_size=200, spatial_crop_prob=0.5, cb_cb_threshold=10, seed=123, data_dir=None):
    """Crop a whole dataset and pickle the cropped sequences and distance maps.

    Sequences are read with ``get_seqs(seq_path)`` and distance maps are
    unpickled from ``dist_path``; both are passed to ``multi_chain_cropping``.
    The results are written to ``train_crop<crop_size>_profiling.pickle``
    (sequences) and ``train_crop<crop_size>_distance_map.pickle`` (distance
    maps) inside ``data_dir``.

    Args:
        seq_path: path handed to ``get_seqs``.
        dist_path: path of the pickled ``{name: dist_map}`` dictionary.
        crop_size: maximum total number of residues per dimer.
        spatial_crop_prob: probability of using spatial cropping.
        cb_cb_threshold: contact threshold used by spatial cropping.
        seed: seed for the numpy RNG.
        data_dir: output directory.  Defaults to the directory containing
            ``dist_path``.
    """
    # The output directory used to be read from an undefined name; make it an
    # explicit, optional argument that defaults to where the input lives.
    if data_dir is None:
        data_dir = os.path.dirname(dist_path)

    sequences = get_seqs(seq_path)
    # NOTE(security): pickle.load executes arbitrary code from the file; only
    # use this on distance-map files from a trusted source.
    with open(dist_path, mode='rb') as f:
        dist_maps = pickle.load(f)
    print(len(dist_maps))

    cr_seqs, cr_dist_maps = multi_chain_cropping(sequences, dist_maps,
                                                crop_size=crop_size,
                                                spatial_crop_prob=spatial_crop_prob,
                                                cb_cb_threshold=cb_cb_threshold,
                                                seed=seed
                                                )

    save_path1 = os.path.join(data_dir, 'train_crop'+str(crop_size)+'_profiling.pickle')
    save_path2 = os.path.join(data_dir, 'train_crop'+str(crop_size)+'_distance_map.pickle')
    with open(save_path1, mode='wb') as f:
        pickle.dump(cr_seqs, f)
    with open(save_path2, mode='wb') as f:
        pickle.dump(cr_dist_maps, f)
    